package resolvers

import (
	"context"
	"sort"
	"time"

	"github.com/graph-gophers/graphql-go"
	"github.com/pkg/errors"
	cveConverter "github.com/stackrox/rox/central/cve/converter/utils"
	"github.com/stackrox/rox/central/graphql/resolvers/embeddedobjs"
	"github.com/stackrox/rox/central/graphql/resolvers/loaders"
	"github.com/stackrox/rox/central/metrics"
	"github.com/stackrox/rox/central/node/mappings"
	v1 "github.com/stackrox/rox/generated/api/v1"
	"github.com/stackrox/rox/generated/storage"
	"github.com/stackrox/rox/pkg/cve"
	pkgMetrics "github.com/stackrox/rox/pkg/metrics"
	"github.com/stackrox/rox/pkg/protocompat"
	"github.com/stackrox/rox/pkg/search"
	"github.com/stackrox/rox/pkg/search/predicate"
	"github.com/stackrox/rox/pkg/search/scoped"
	"github.com/stackrox/rox/pkg/utils"
)

var (
	// nodeVulnerabilityPredicateFactory generates predicates over storage.NodeVulnerability
	// messages. getNodeCVEResolvers uses it to turn a query into an in-memory filter for
	// vulnerabilities that are already embedded in a component.
	nodeVulnerabilityPredicateFactory = predicate.NewFactory("vulnerability", &storage.NodeVulnerability{})
)

// init registers the hand-written NodeComponent fields and the root node
// component queries on the shared schema builder. utils.Must receives the
// result of every registration call.
func init() {
	builder := getBuilder()

	// Resolvers for fields in storage.NodeComponent are autogenerated and located in generated.go
	// NOTE: This list is and should remain alphabetically ordered
	extraNodeComponentFields := []string{
		"fixedIn: String!",
		"lastScanned: Time",
		"location(query: String): String!",
		"nodes(query: String, scopeQuery: String, pagination: Pagination): [Node!]!",
		"nodeCount(query: String, scopeQuery: String): Int!",
		"nodeVulnerabilities(query: String, scopeQuery: String, pagination: Pagination): [NodeVulnerability]!",
		"nodeVulnerabilityCount(query: String, scopeQuery: String): Int!",
		"nodeVulnerabilityCounter(query: String): VulnerabilityCounter!",
		"plottedNodeVulnerabilities(query: String): PlottedNodeVulnerabilities!",
		"source: String!",
		"topNodeVulnerability: NodeVulnerability",
		"unusedVarSink(query: String): Int",
	}

	utils.Must(
		builder.AddExtraResolvers("NodeComponent", extraNodeComponentFields),

		builder.AddQuery("nodeComponent(id: ID): NodeComponent"),
		builder.AddQuery("nodeComponents(query: String, scopeQuery: String, pagination: Pagination): [NodeComponent!]!"),
		builder.AddQuery("nodeComponentCount(query: String): Int!"),
	)
}

// NodeComponentResolver represents a generic resolver of node component fields.
//
// The methods implemented in this file back the extra fields registered in init
// (fixedIn, lastScanned, location, nodes, nodeCount, nodeVulnerabilities, ...).
// The remaining accessors (Id, Name, OperatingSystem, Priority, RiskScore, Version)
// are not defined in this file; presumably they are the autogenerated resolvers
// for storage.NodeComponent fields mentioned in init — confirm against generated.go.
//
// NOTE: This list is and should remain alphabetically ordered
type NodeComponentResolver interface {
	FixedIn(ctx context.Context) string
	Id(ctx context.Context) graphql.ID
	LastScanned(ctx context.Context) (*graphql.Time, error)
	Location(ctx context.Context, args RawQuery) (string, error)
	Name(ctx context.Context) string
	Nodes(ctx context.Context, args PaginatedQuery) ([]*nodeResolver, error)
	NodeCount(ctx context.Context, args RawQuery) (int32, error)
	NodeVulnerabilities(ctx context.Context, args PaginatedQuery) ([]NodeVulnerabilityResolver, error)
	NodeVulnerabilityCount(ctx context.Context, args RawQuery) (int32, error)
	NodeVulnerabilityCounter(ctx context.Context, args RawQuery) (*VulnerabilityCounterResolver, error)
	OperatingSystem(ctx context.Context) string
	PlottedNodeVulnerabilities(ctx context.Context, args RawQuery) (*PlottedNodeVulnerabilitiesResolver, error)
	Priority(ctx context.Context) int32
	RiskScore(ctx context.Context) float64
	Source(ctx context.Context) string
	TopNodeVulnerability(ctx context.Context) (NodeVulnerabilityResolver, error)
	UnusedVarSink(ctx context.Context, args RawQuery) *int32
	Version(ctx context.Context) string
}

// NodeComponent returns a node component based on an input id (name:version).
//
// It first runs the readNodes check, then loads the component through the
// node component loader attached to ctx and wraps it in a resolver.
//
// An error is returned when the readNodes check fails, when no id argument was
// supplied, when the loader cannot be obtained, or when loading/wrapping fails.
func (resolver *Resolver) NodeComponent(ctx context.Context, args IDQuery) (NodeComponentResolver, error) {
	defer metrics.SetGraphQLOperationDurationTime(time.Now(), pkgMetrics.Root, "NodeComponent")

	if err := readNodes(ctx); err != nil {
		return nil, err
	}
	// The schema declares `nodeComponent(id: ID)` with a nullable id, so the
	// pointer may be nil; reject that explicitly instead of dereferencing it.
	if args.ID == nil {
		return nil, errors.New("id is required to resolve a node component")
	}
	loader, err := loaders.GetNodeComponentLoader(ctx)
	if err != nil {
		return nil, err
	}

	ret, err := loader.FromID(ctx, string(*args.ID))
	wrapped, err := resolver.wrapNodeComponentWithContext(ctx, ret, true, err)
	// Return a literal nil rather than forwarding the wrapper's result directly:
	// a nil pointer stored in the NodeComponentResolver interface would compare
	// as non-nil to the caller.
	if err != nil || wrapped == nil {
		return nil, err
	}
	return wrapped, nil
}

// NodeComponents returns node components that match the input query.
//
// The readNodes check runs first; the paginated query is then converted to a
// v1 query, executed through the node component loader found on ctx, and each
// result is wrapped in a resolver. The first error encountered is returned.
func (resolver *Resolver) NodeComponents(ctx context.Context, q PaginatedQuery) ([]NodeComponentResolver, error) {
	defer metrics.SetGraphQLOperationDurationTime(time.Now(), pkgMetrics.Root, "NodeComponents")

	if accessErr := readNodes(ctx); accessErr != nil {
		return nil, accessErr
	}

	v1Query, err := q.AsV1QueryOrEmpty()
	if err != nil {
		return nil, err
	}

	componentLoader, err := loaders.GetNodeComponentLoader(ctx)
	if err != nil {
		return nil, err
	}

	// The loader error is handed to the wrapper, which surfaces it below.
	components, loadErr := componentLoader.FromQuery(ctx, v1Query)
	wrapped, err := resolver.wrapNodeComponentsWithContext(ctx, components, loadErr)
	if err != nil {
		return nil, err
	}

	// Convert the concrete resolvers to the interface type element by element.
	out := make([]NodeComponentResolver, len(wrapped))
	for i := range wrapped {
		out[i] = wrapped[i]
	}
	return out, nil
}

// NodeComponentCount returns count of node components that match the input query.
//
// It fails with the first error from the readNodes check, the query conversion,
// the loader lookup, or the count itself; in the first three cases the count is 0.
func (resolver *Resolver) NodeComponentCount(ctx context.Context, args RawQuery) (int32, error) {
	defer metrics.SetGraphQLOperationDurationTime(time.Now(), pkgMetrics.Root, "NodeComponentCount")

	if accessErr := readNodes(ctx); accessErr != nil {
		return 0, accessErr
	}

	v1Query, queryErr := args.AsV1QueryOrEmpty()
	if queryErr != nil {
		return 0, queryErr
	}

	componentLoader, loaderErr := loaders.GetNodeComponentLoader(ctx)
	if loaderErr != nil {
		return 0, loaderErr
	}

	return componentLoader.CountFromQuery(ctx, v1Query)
}

/*
Utility Functions
*/

// nodeComponentScopeContext returns a context scoped to this node component.
//
// A nil ctx is reported through utils.ShouldErr and logged, but does not abort.
// When the resolver has no context of its own yet, ctx is stored on it; the
// scope (NODE_COMPONENTS level, this component's ID) is then applied on top of
// the resolver's context.
func (resolver *nodeComponentResolver) nodeComponentScopeContext(ctx context.Context) context.Context {
	if ctx == nil {
		if err := utils.ShouldErr(errors.New("argument 'ctx' is nil")); err != nil {
			log.Error(err)
		}
	}
	if resolver.ctx == nil {
		resolver.ctx = ctx
	}

	componentScope := scoped.Scope{
		Level: v1.SearchCategory_NODE_COMPONENTS,
		IDs:   []string{resolver.data.GetId()},
	}
	return scoped.Context(resolver.ctx, componentScope)
}

// nodeComponentQuery builds a query that exactly matches this component's ID
// on the search.ComponentID field.
func (resolver *nodeComponentResolver) nodeComponentQuery() *v1.Query {
	componentID := resolver.data.GetId()
	builder := search.NewQueryBuilder().AddExactMatches(search.ComponentID, componentID)
	return builder.ProtoQuery()
}

// getNodeCVEResolvers converts the given embedded node vulnerabilities into
// node CVE resolvers, applying the filter and pagination carried by query.
//
// Parameters:
//   - ctx: base context; each returned resolver gets a context derived from it
//     that carries the vulnerability it was built from.
//   - root: root resolver used to wrap the converted CVEs.
//   - os: operating system used to derive the CVE ID and to convert each vulnerability.
//   - vulns: candidate vulnerabilities; entries that do not match the query are skipped
//     and entries that map to an already-seen CVE ID are dropped (first occurrence wins).
//   - query: filter and pagination; it is first restricted with
//     mappings.NodeVulnerabilityOptionsMap.
//
// When the query carries no sort options the resolvers are ordered by ID.
// Otherwise they keep the order in which they first appear in vulns, so the
// pages handed to paginate are stable across calls.
//
// An error is returned when the predicate cannot be generated or a CVE cannot be wrapped.
func getNodeCVEResolvers(ctx context.Context, root *Resolver, os string, vulns []*storage.NodeVulnerability, query *v1.Query) ([]NodeVulnerabilityResolver, error) {
	// NOTE(review): the second result of FilterQueryWithMap is deliberately discarded,
	// as before; confirm it carries nothing this function needs.
	query, _ = search.FilterQueryWithMap(query, mappings.NodeVulnerabilityOptionsMap)
	// Named vulnPredicate so the imported predicate package is not shadowed.
	vulnPredicate, err := nodeVulnerabilityPredicateFactory.GeneratePredicate(query)
	if err != nil {
		return nil, err
	}

	// De-duplicate by CVE ID while remembering first-seen order. Collecting the
	// resolvers from a map instead would yield a random order whenever the ID
	// sort below is skipped.
	seenIDs := make(map[string]struct{}, len(vulns))
	resolverObjs := make([]*nodeCVEResolver, 0, len(vulns))
	for _, vuln := range vulns {
		if !vulnPredicate.Matches(vuln) {
			continue
		}
		id := cve.ID(vuln.GetCveBaseInfo().GetCve(), os)
		if _, exists := seenIDs[id]; exists {
			continue
		}
		converted := cveConverter.NodeVulnerabilityToNodeCVE(os, vuln)
		cveResolver, err := root.wrapNodeCVE(converted, true, nil)
		if err != nil {
			return nil, err
		}
		cveResolver.ctx = embeddedobjs.NodeVulnContext(ctx, vuln)
		seenIDs[id] = struct{}{}
		resolverObjs = append(resolverObjs, cveResolver)
	}

	// For now, sort by ID unless the caller asked for a specific ordering.
	if len(query.GetPagination().GetSortOptions()) == 0 {
		sort.SliceStable(resolverObjs, func(i, j int) bool {
			return resolverObjs[i].data.GetId() < resolverObjs[j].data.GetId()
		})
	}

	nodeVulnResolvers := make([]NodeVulnerabilityResolver, 0, len(resolverObjs))
	for _, cveResolver := range resolverObjs {
		nodeVulnResolvers = append(nodeVulnResolvers, cveResolver)
	}
	return paginate(query.GetPagination(), nodeVulnResolvers, nil)
}

/*
Sub Resolver Functions
*/

// FixedIn returns the node component version that fixes all the fixable vulnerabilities in this component.
// The current implementation always reports an empty string.
func (resolver *nodeComponentResolver) FixedIn(_ context.Context) string {
	startedAt := time.Now()
	defer metrics.SetGraphQLOperationDurationTime(startedAt, pkgMetrics.NodeComponents, "FixedIn")

	return ""
}

// LastScanned is the last time the node component was scanned in a node.
//
// If the resolver's context already carries a last-scanned time, that value is
// returned directly. Otherwise the node loader is queried for the node
// containing this component with the most recent scan time, and that node's
// scan time is returned. A nil time with a nil error is returned when no node matches.
func (resolver *nodeComponentResolver) LastScanned(ctx context.Context) (*graphql.Time, error) {
	defer metrics.SetGraphQLOperationDurationTime(time.Now(), pkgMetrics.NodeComponents, "LastScanned")
	if resolver.ctx == nil {
		resolver.ctx = ctx
	}

	// Short path: the scan time is embedded in the context, so no node lookup is needed.
	if embeddedScanTime := embeddedobjs.NodeComponentLastScannedFromContext(resolver.ctx); embeddedScanTime != nil {
		return &graphql.Time{Time: *embeddedScanTime}, nil
	}

	nodeLoader, err := loaders.GetNodeLoader(resolver.ctx)
	if err != nil {
		return nil, err
	}

	// Ask for a single node, most recently scanned first.
	latestScanFirst := &v1.QuerySortOption{
		Field:    search.NodeScanTime.String(),
		Reversed: true,
	}
	componentQuery := resolver.nodeComponentQuery()
	componentQuery.Pagination = &v1.QueryPagination{
		Limit:       1,
		Offset:      0,
		SortOptions: []*v1.QuerySortOption{latestScanFirst},
	}

	nodes, err := nodeLoader.FromQuery(resolver.ctx, componentQuery)
	switch {
	case err != nil, len(nodes) == 0:
		return nil, err
	case len(nodes) > 1:
		return nil, errors.New("multiple nodes matched for last scanned component query")
	}

	return protocompat.ConvertTimestampToGraphqlTimeOrError(nodes[0].GetScan().GetScanTime())
}

// Location of the node component.
// Both arguments are ignored; the fixed placeholder "Not Available" is always returned.
func (resolver *nodeComponentResolver) Location(_ context.Context, _ RawQuery) (string, error) {
	startedAt := time.Now()
	defer metrics.SetGraphQLOperationDurationTime(startedAt, pkgMetrics.NodeComponents, "Location")

	return "Not Available", nil
}

// Nodes that contain the node component.
// Delegates to the root Nodes resolver with a context scoped to this component.
func (resolver *nodeComponentResolver) Nodes(ctx context.Context, args PaginatedQuery) ([]*nodeResolver, error) {
	defer metrics.SetGraphQLOperationDurationTime(time.Now(), pkgMetrics.NodeComponents, "Nodes")

	scopedCtx := resolver.nodeComponentScopeContext(ctx)
	return resolver.root.Nodes(scopedCtx, args)
}

// NodeCount is the number of nodes that contain the node component.
// Delegates to the root NodeCount resolver with a context scoped to this component.
func (resolver *nodeComponentResolver) NodeCount(ctx context.Context, args RawQuery) (int32, error) {
	defer metrics.SetGraphQLOperationDurationTime(time.Now(), pkgMetrics.NodeComponents, "NodeCount")

	scopedCtx := resolver.nodeComponentScopeContext(ctx)
	return resolver.root.NodeCount(scopedCtx, args)
}

// NodeVulnerabilities contained in the node component.
//
// When the resolver's context carries an embedded node component, its
// vulnerabilities are filtered and paginated in memory via getNodeCVEResolvers.
// Otherwise the call is delegated to the root NodeVulnerabilities resolver
// with a context scoped to this component.
func (resolver *nodeComponentResolver) NodeVulnerabilities(ctx context.Context, args PaginatedQuery) ([]NodeVulnerabilityResolver, error) {
	defer metrics.SetGraphQLOperationDurationTime(time.Now(), pkgMetrics.NodeComponents, "NodeVulnerabilities")

	if resolver.ctx == nil {
		resolver.ctx = ctx
	}

	// Short path: the component is embedded in the context, so no scoped query is needed.
	if embedded := embeddedobjs.NodeComponentFromContext(resolver.ctx); embedded != nil {
		v1Query, err := args.AsV1QueryOrEmpty()
		if err != nil {
			return nil, err
		}
		operatingSystem := resolver.data.GetOperatingSystem()
		return getNodeCVEResolvers(resolver.ctx, resolver.root, operatingSystem, embedded.GetVulnerabilities(), v1Query)
	}

	scopedCtx := resolver.nodeComponentScopeContext(ctx)
	return resolver.root.NodeVulnerabilities(scopedCtx, args)
}

// NodeVulnerabilityCount resolves the number of node vulnerabilities contained in the node component.
// Delegates to the root NodeVulnerabilityCount resolver with a context scoped to this component.
func (resolver *nodeComponentResolver) NodeVulnerabilityCount(ctx context.Context, args RawQuery) (int32, error) {
	defer metrics.SetGraphQLOperationDurationTime(time.Now(), pkgMetrics.NodeComponents, "NodeVulnerabilityCount")

	scopedCtx := resolver.nodeComponentScopeContext(ctx)
	return resolver.root.NodeVulnerabilityCount(scopedCtx, args)
}

// NodeVulnerabilityCounter resolves the number of different types of node vulnerabilities contained in a node component.
// Delegates to the root NodeVulnerabilityCounter resolver with a context scoped to this component.
func (resolver *nodeComponentResolver) NodeVulnerabilityCounter(ctx context.Context, args RawQuery) (*VulnerabilityCounterResolver, error) {
	defer metrics.SetGraphQLOperationDurationTime(time.Now(), pkgMetrics.NodeComponents, "NodeVulnerabilityCounter")

	scopedCtx := resolver.nodeComponentScopeContext(ctx)
	return resolver.root.NodeVulnerabilityCounter(scopedCtx, args)
}

// PlottedNodeVulnerabilities returns the data required by top risky component scatter-plot on vuln mgmt dashboard.
// Delegates to the root PlottedNodeVulnerabilities resolver with a context scoped to this component.
func (resolver *nodeComponentResolver) PlottedNodeVulnerabilities(ctx context.Context, args RawQuery) (*PlottedNodeVulnerabilitiesResolver, error) {
	defer metrics.SetGraphQLOperationDurationTime(time.Now(), pkgMetrics.NodeComponents, "PlottedNodeVulnerabilities")

	scopedCtx := resolver.nodeComponentScopeContext(ctx)
	return resolver.root.PlottedNodeVulnerabilities(scopedCtx, args)
}

// Source returns the source type of the node component.
// The value is always the string form of storage.SourceType_INFRASTRUCTURE.
func (resolver *nodeComponentResolver) Source(_ context.Context) string {
	startedAt := time.Now()
	defer metrics.SetGraphQLOperationDurationTime(startedAt, pkgMetrics.NodeComponents, "Source")

	sourceType := storage.SourceType_INFRASTRUCTURE
	return sourceType.String()
}

// TopNodeVulnerability returns the first node component vulnerability with the top CVSS score.
//
// When the resolver's context carries an embedded node component, its
// vulnerabilities are scanned in order and the one with the strictly highest
// CVSS wins, so the earliest entry is kept on ties. A nil resolver with a nil
// error is returned if the embedded component has no vulnerabilities.
// Otherwise the call is delegated to the root TopNodeVulnerability resolver
// with a context scoped to this component and an empty query.
func (resolver *nodeComponentResolver) TopNodeVulnerability(ctx context.Context) (NodeVulnerabilityResolver, error) {
	defer metrics.SetGraphQLOperationDurationTime(time.Now(), pkgMetrics.NodeComponents, "TopNodeVulnerability")

	// Adopt the request context when the resolver has none, matching LastScanned
	// and NodeVulnerabilities, so the embedded-object lookup below never
	// receives a nil context.
	if resolver.ctx == nil {
		resolver.ctx = ctx
	}

	// Short path. Full node is embedded when node scan resolver is called.
	if embeddedComponent := embeddedobjs.NodeComponentFromContext(resolver.ctx); embeddedComponent != nil {
		var topVuln *storage.NodeVulnerability
		for _, vuln := range embeddedComponent.GetVulnerabilities() {
			if topVuln == nil || vuln.GetCvss() > topVuln.GetCvss() {
				topVuln = vuln
			}
		}
		if topVuln == nil {
			return nil, nil
		}
		return resolver.root.wrapNodeCVE(
			cveConverter.NodeVulnerabilityToNodeCVE(resolver.data.GetOperatingSystem(), topVuln), true, nil,
		)
	}

	return resolver.root.TopNodeVulnerability(resolver.nodeComponentScopeContext(ctx), RawQuery{})
}

// UnusedVarSink represents a query sink.
// It backs the `unusedVarSink(query: String): Int` field registered in init,
// ignores both of its arguments and always resolves to nil.
// NOTE(review): presumably exists so clients can pass a query variable they do
// not otherwise use — confirm with the UI callers.
func (resolver *nodeComponentResolver) UnusedVarSink(_ context.Context, _ RawQuery) *int32 {
	return nil
}
